-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][Transforms] ConversionPatternRewriter
: Add config
getter
#152310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesAdd a helper function to Also remove the Full diff: https://github.com/llvm/llvm-project/pull/152310.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f6437657c9a93..4e651a0489899 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
public:
~ConversionPatternRewriter() override;
+ /// Return the configuration of the current dialect conversion.
+ const ConversionConfig &getConfig() const;
+
/// Apply a signature conversion to given block. This replaces the block with
/// a new block containing the updated signature. The operations of the given
/// block are inlined into the newly-created block, which is returned.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c6197accd5..a55da79455010 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1754,6 +1754,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
+const ConversionConfig &ConversionPatternRewriter::getConfig() const {
+ return impl->config;
+}
+
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
@@ -1895,7 +1899,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// ops should be moved one-by-one ("slow path"), so that a separate
// `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
// a bit more efficient, so we try to do that when possible.
- bool fastPath = !impl->config.listener;
+ bool fastPath = !getConfig().listener;
if (fastPath)
impl->inlineBlockBefore(source, dest, before);
@@ -2018,8 +2022,7 @@ class OperationLegalizer {
using LegalizationAction = ConversionTarget::LegalizationAction;
OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns,
- const ConversionConfig &config);
+ const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
bool isIllegal(Operation *op) const;
@@ -2116,16 +2119,12 @@ class OperationLegalizer {
/// The pattern applicator to use for conversions.
PatternApplicator applicator;
-
- /// Dialect conversion configuration.
- const ConversionConfig &config;
};
} // namespace
OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns,
- const ConversionConfig &config)
- : target(targetInfo), applicator(patterns), config(config) {
+ const FrozenRewritePatternSet &patterns)
+ : target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2286,7 +2285,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
- if (!config.allowPatternRollback) {
+ if (!rewriter.getConfig().allowPatternRollback) {
// Rolling back a folder is like rolling back a pattern.
llvm::report_fatal_error(
"op '" + opName +
@@ -2306,6 +2305,7 @@ LogicalResult
OperationLegalizer::legalizeWithPattern(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
+ const ConversionConfig &config = rewriter.getConfig();
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
Operation *checkOp;
@@ -2749,8 +2749,7 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : config(config), opLegalizer(target, patterns, this->config),
- mode(mode) {}
+ : config(config), opLegalizer(target, patterns), mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd a helper function to Also remove the Full diff: https://github.com/llvm/llvm-project/pull/152310.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f6437657c9a93..4e651a0489899 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter {
public:
~ConversionPatternRewriter() override;
+ /// Return the configuration of the current dialect conversion.
+ const ConversionConfig &getConfig() const;
+
/// Apply a signature conversion to given block. This replaces the block with
/// a new block containing the updated signature. The operations of the given
/// block are inlined into the newly-created block, which is returned.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f23c6197accd5..a55da79455010 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1754,6 +1754,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
+const ConversionConfig &ConversionPatternRewriter::getConfig() const {
+ return impl->config;
+}
+
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
assert(op && newOp && "expected non-null op");
replaceOp(op, newOp->getResults());
@@ -1895,7 +1899,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// ops should be moved one-by-one ("slow path"), so that a separate
// `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
// a bit more efficient, so we try to do that when possible.
- bool fastPath = !impl->config.listener;
+ bool fastPath = !getConfig().listener;
if (fastPath)
impl->inlineBlockBefore(source, dest, before);
@@ -2018,8 +2022,7 @@ class OperationLegalizer {
using LegalizationAction = ConversionTarget::LegalizationAction;
OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns,
- const ConversionConfig &config);
+ const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
bool isIllegal(Operation *op) const;
@@ -2116,16 +2119,12 @@ class OperationLegalizer {
/// The pattern applicator to use for conversions.
PatternApplicator applicator;
-
- /// Dialect conversion configuration.
- const ConversionConfig &config;
};
} // namespace
OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns,
- const ConversionConfig &config)
- : target(targetInfo), applicator(patterns), config(config) {
+ const FrozenRewritePatternSet &patterns)
+ : target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2286,7 +2285,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
- if (!config.allowPatternRollback) {
+ if (!rewriter.getConfig().allowPatternRollback) {
// Rolling back a folder is like rolling back a pattern.
llvm::report_fatal_error(
"op '" + opName +
@@ -2306,6 +2305,7 @@ LogicalResult
OperationLegalizer::legalizeWithPattern(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
+ const ConversionConfig &config = rewriter.getConfig();
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
Operation *checkOp;
@@ -2749,8 +2749,7 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : config(config), opLegalizer(target, patterns, this->config),
- mode(mode) {}
+ : config(config), opLegalizer(target, patterns), mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
|
Here's an example where a pattern would check the |
config
getterConversionPatternRewriter
: Add config
getter
Add a helper function to
ConversionPatternRewriter
that returns the dialect conversion configuration. This flag is useful when migrating conversion patterns to the new One-Shot Conversion Driver: patterns can check if they are running in rollback mode or not. They can then work around API changes and makes sure that the pattern keeps working with both the old and new driver.Also remove the
config
field fromOperationLegalizer
. That field was never needed.